import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler

from loss import weighted_mae
from utils import add_parser


class Unet(pl.LightningModule):
    """2-D U-Net that maps a stack of input frames to a stack of predicted frames.

    The time and channel axes of the input are folded into the convolutional channel
    axis, passed through a three-level encoder/decoder with skip connections, and
    unfolded again at the output (see ``forward``).

    NOTE(review): several constructor arguments are only recorded in ``self.hparams``
    and are not read anywhere in this class: ``height``, ``width``, ``downscale_factor``,
    ``train_log_steps``, ``val_log_steps``, ``test_save_path``, ``first_channel`` and
    ``trilinear``. The ``loss_fx`` argument is also ignored; the loss is always a
    ``weighted_mae`` built from ``weights_prec`` / ``thresholds_prec``. The encoder
    widths (64/128/256/512) are hard-coded rather than derived from ``first_channel``.
    """

    def __init__(self, height, width, input_length, target_length, downscale_factor, learning_rate, loss_fx,
                 input_channels, predict_channels, weights_prec, thresholds_prec,
                 train_log_steps, val_log_steps, test_save_path, lr_scheduler_gamma,
                 first_channel=16, scheduler='reducelr', trilinear=False, **kwargs):

        super(Unet, self).__init__()
        self.save_hyperparameters()
        self.num_channels_in = input_channels
        self.num_channels_out = predict_channels

        # The ``loss_fx`` argument is deliberately overridden here.
        self.loss_fx = weighted_mae(weights_prec=self.hparams.weights_prec,
                                    thresholds_prec=self.hparams.thresholds_prec)

        # All input frames are stacked on the channel axis, so the first conv sees
        # input_channels * input_length channels; the head emits every target frame at once.
        self.inc = DoubleConv(self.num_channels_in * input_length, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.up1 = Up(512, 256)
        self.up2 = Up(256, 128)
        self.up3 = Up(128, 64)
        self.outc = OutConv(64, self.num_channels_out * target_length)

    def forward(self, x):
        '''
        Predict ``target_length`` frames from a sequence of input frames.

        :param x: [nbatchs, ninput_seqs, nchannel, nheight, nwidth]; ninput_seqs * nchannel
            must equal input_length * input_channels (the first conv's input channels).
        :return: [nbatchs, target_length, predict_channels, nheight, nwidth], produced by
            the sigmoid head (OutConv).
        '''
        nbatchs, _, _, nheight, nwidth = x.shape
        # Fold time into channels: [B, T, C, H, W] -> [B, T*C, H, W].
        x = x.reshape(nbatchs, -1, nheight, nwidth)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        out = self.outc(x)
        # Unfold channels back into (target_length, predict_channels).
        return out.reshape(nbatchs, self.hparams.target_length, self.num_channels_out, nheight, nwidth)

    def _clipped_loss(self, batch):
        """Run the model on an ``(inputs, targets)`` batch and return the loss.

        Predictions are clipped to [0, 1] before being passed to ``self.loss_fx``.
        """
        seqs_x, seqs_y = batch
        decoder_frames = torch.clip(self(seqs_x), 0, 1)
        return self.loss_fx(decoder_frames, seqs_y)

    def training_step(self, batch, batch_idx):
        """Return the clipped-prediction loss for one training batch."""
        return self._clipped_loss(batch)

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch.

        :return: ``{'valid_loss_fx': <python float>}``
        """
        valid_loss_fx = self._clipped_loss(batch)
        # The ReduceLROnPlateau scheduler in configure_optimizers monitors
        # 'epoch_val_loss'. Nothing else in this class logs that key, so log it here
        # as an epoch-level aggregate.
        self.log('epoch_val_loss', valid_loss_fx, on_step=False, on_epoch=True, prog_bar=True)
        metrics_pred = {'valid_loss_fx': valid_loss_fx.item()}
        return metrics_pred

    def configure_optimizers(self):
        """Build Adam plus the LR scheduler selected by ``hparams.scheduler``.

        Supported values: 'exp' (ExponentialLR with ``lr_scheduler_gamma``),
        'cosine' (CosineAnnealingLR, T_max=10) and 'reducelr' (ReduceLROnPlateau on
        'epoch_val_loss').

        :raises ValueError: for any other ``scheduler`` value. Previously such values
            returned None, which trains without an optimizer.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-8)
        scheduler_name = self.hparams.scheduler

        if scheduler_name == 'exp':
            scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=self.hparams.lr_scheduler_gamma)
            return [optimizer], [scheduler]
        if scheduler_name == 'cosine':
            # torch provides CosineAnnealingLR; there is no lr_scheduler.CosineAnnealing.
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
            return [optimizer], [scheduler]
        if scheduler_name == 'reducelr':
            return {
                'optimizer': optimizer,
                'lr_scheduler': {
                    'scheduler': lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8,
                                                                patience=5, verbose=True),
                    'monitor': 'epoch_val_loss',
                }
            }
        raise ValueError(
            f"Unknown scheduler {scheduler_name!r}; expected 'exp', 'cosine' or 'reducelr'"
        )

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register this model's CLI options on ``parent_parser`` and return it.

        Shared options come from ``add_parser``; the rest are added to an argument
        group. NOTE(review): the group is titled 'Unet3d' although this is a 2-D model.
        """
        parser = parent_parser.add_argument_group('Unet3d')
        add_parser(parser)
        parser.add_argument('--lr_scheduler_gamma', type=float, default=0.98)
        parser.add_argument('--first_channel', type=int, default=16)
        parser.add_argument('--scheduler', type=str, default="reducelr")
        parser.add_argument('--trilinear', type=int, default=0)
        parser.add_argument('--weights_prec', nargs='+', type=float, default=[1, 1, 2.5, 5.0, 10, 20])
        parser.add_argument('--thresholds_prec', nargs='+', type=float, default=[0.00167, 0.0167, 0.083, 0.167, 0.333])
        return parent_parser

class DoubleConv(nn.Module):
    """Two stacked 3x3 conv stages: (Conv2d -> GroupNorm(16) -> LeakyReLU(0.2)) x 2.

    Both convolutions use padding=1, so the spatial size is preserved. GroupNorm with
    16 groups requires the mid and output channel counts to be divisible by 16.

    :param in_channels: channels of the input tensor
    :param out_channels: channels produced by the second convolution
    :param mid_channels: channels produced by the first convolution; any falsy value
        (None, 0) means "use out_channels"
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        layers = []
        # Each stage appends conv, norm, activation; the resulting indices (0..5) match
        # the original layout, so state_dict keys are unchanged.
        for stage_in, stage_out in ((in_channels, hidden), (hidden, out_channels)):
            layers.append(nn.Conv2d(stage_in, stage_out, kernel_size=3, padding=1))
            layers.append(nn.GroupNorm(16, stage_out))
            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Halve the spatial resolution with a strided 3x3 conv, then apply DoubleConv.

    The first convolution keeps ``in_channels`` channels and uses stride=2, padding=1,
    so an H x W input becomes ceil(H/2) x ceil(W/2). ``DoubleConv`` then maps the
    channels to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Construct in the original order (strided conv first) so parameter
        # initialisation and state_dict keys (pool_conv.0 / pool_conv.1) are unchanged.
        downsample = nn.Conv2d(in_channels=in_channels, out_channels=in_channels,
                               kernel_size=3, stride=2, padding=1)
        refine = DoubleConv(in_channels, out_channels)
        self.pool_conv = nn.Sequential(downsample, refine)

    def forward(self, x):
        return self.pool_conv(x)


class Up(nn.Module):
    """Upsample ``x1`` by 2x, align it to the skip tensor ``x2``, concatenate, then DoubleConv.

    With ``bilinear=False`` (the default), a stride-2 transposed conv upsamples and halves
    the channels (in_channels -> in_channels // 2). The skip tensor must then carry
    in_channels // 2 channels so the concatenation has ``in_channels`` channels.

    With ``bilinear=True``, upsampling keeps the channel count of ``x1``. The
    concatenation of ``x1`` and ``x2`` must then itself total ``in_channels`` channels,
    and the DoubleConv's first convolution narrows them to in_channels // 2.
    """

    def __init__(self, in_channels, out_channels, bilinear=False):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Parameter-free upsampling; the DoubleConv does the channel reduction.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Tensors are (N, C, H, W). Match the upsampled map's H and W to the skip tensor.
        # They can differ when a spatial size was odd on the way down; a negative
        # difference makes F.pad crop instead of pad.
        pad_h = x2.size()[2] - upsampled.size()[2]
        pad_w = x2.size()[3] - upsampled.size()[3]
        top = pad_h // 2
        left = pad_w // 2
        upsampled = F.pad(upsampled, [left, pad_w - left, top, pad_h - top])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Output head: a 1x1 convolution followed by a sigmoid, squashing values into [0, 1]."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)
        return logits.sigmoid()

